//
// Created by pulsarv on 19-12-8.
//

#include <road_segmentation.h>
#include <base/slog.hpp>
#include <base/common.hpp>
#include <src/common/base/ocv_common.hpp>

namespace MobileSearch {

    /// Loads the road-segmentation model, configures its single input/output,
    /// and compiles it for @p deviceName so createInferRequest() can be used.
    ///
    /// @param ie                       Inference Engine core used for loading.
    /// @param deviceName               Target device (e.g. "CPU", "GPU").
    /// @param xmlPath                  NOTE(review): currently unused — the topology
    ///                                 is read from @p model_segmentation_path; confirm intent.
    /// @param auto_resize              If true, let the plugin resize input (NHWC +
    ///                                 bilinear resize); otherwise caller-side NCHW copy.
    /// @param pluginConfig             Extra plugin configuration key/value pairs.
    /// @param model_segmentation_path  Path to the .xml topology; weights are expected
    ///                                 alongside it with a .bin extension.
    /// @throws std::logic_error if the network does not have exactly 1 input / 1 output.
    RoadSegmentation::RoadSegmentation(InferenceEngine::Core &ie,
                                       const std::string &deviceName,
                                       const std::string &xmlPath,
                                       const bool auto_resize,
                                       const std::map<std::string,
                                               std::string> &pluginConfig,
                                       const std::string &model_segmentation_path
    ) : ie_(ie) {
        slog::info << "Loading Segmentation model files" << slog::endl;

        InferenceEngine::CNNNetReader road_networkReader;
        /** Read Road network model **/
        road_networkReader.ReadNetwork(model_segmentation_path);

        /** Extract model name and load weights **/
        std::string binFileName = fileNameNoExt(model_segmentation_path) + ".bin";
        road_networkReader.ReadWeights(binFileName);
        slog::info << model_segmentation_path << slog::endl;
        slog::info << binFileName << slog::endl;

        InferenceEngine::CNNNetwork network = road_networkReader.getNetwork();

        // Validate BEFORE dereferencing: begin() on an empty InputsDataMap is UB.
        InferenceEngine::InputsDataMap segmentationInputInfo(network.getInputsInfo());
        if (segmentationInputInfo.size() != 1)
            throw std::logic_error("Segmentation Input supports topologies only with 1 input");

        InferenceEngine::InputInfo::Ptr &segmentationInputInfoFirst = segmentationInputInfo.begin()->second;
        // Remember the input blob name — setImage() looks it up on every frame.
        segmentationInputBlobName = segmentationInputInfo.begin()->first;

        segmentationInputInfoFirst->setPrecision(InferenceEngine::Precision::U8);
        if (auto_resize) {
            // Plugin-side resize: accept NHWC frames and resize bilinearly on device.
            segmentationInputInfoFirst->getPreProcess().setResizeAlgorithm(
                    InferenceEngine::ResizeAlgorithm::RESIZE_BILINEAR);
            segmentationInputInfoFirst->setLayout(InferenceEngine::Layout::NHWC);
        } else {
            segmentationInputInfoFirst->setLayout(InferenceEngine::Layout::NCHW);
        }

        // Record the output blob name and force FP32 — getResults() reads the
        // output buffer as float*.
        InferenceEngine::OutputsDataMap segmentationOutputInfo(network.getOutputsInfo());
        if (segmentationOutputInfo.size() != 1)
            throw std::logic_error("Segmentation network should have exactly 1 output");
        segmentationOutputBlobName = segmentationOutputInfo.begin()->first;
        segmentationOutputInfo.begin()->second->setPrecision(InferenceEngine::Precision::FP32);

        // Compile the network for the target device; createInferRequest() depends
        // on `net` being populated here.
        net = ie_.LoadNetwork(network, deviceName, pluginConfig);
    }

    /// Creates a fresh inference request bound to the compiled segmentation network.
    /// @return a new InferRequest; the caller drives it via setImage()/getResults().
    InferenceEngine::InferRequest RoadSegmentation::createInferRequest() {
        auto request = net.CreateInferRequest();
        return request;
    }

    /// Feeds @p img into the request's input blob.
    ///
    /// If the input layout is NHWC (auto-resize configured), the Mat's memory is
    /// wrapped directly into a Blob with zero copy; otherwise the pixels are
    /// copied/converted into the NCHW blob.
    ///
    /// NOTE(review): @p segmentation_img is never read in this function — confirm
    /// whether callers still need the parameter.
    ///
    /// @throws std::logic_error if zero-copy is requested for an ROI submatrix
    ///         (its rows are not contiguous, so it cannot be wrapped).
    void RoadSegmentation::setImage(InferenceEngine::InferRequest &inferRequest, const cv::Mat &img,
                                    const cv::Mat &segmentation_img) {
        InferenceEngine::Blob::Ptr inputBlob = inferRequest.GetBlob(segmentationInputBlobName);
        const bool pluginResizes =
                (inputBlob->getTensorDesc().getLayout() == InferenceEngine::Layout::NHWC);

        if (!pluginResizes) {
            // NCHW path: explicit copy/convert of the U8 pixels into the blob.
            matU8ToBlob<uint8_t>(img, inputBlob);
            return;
        }

        // autoResize path: only a dense (non-ROI) matrix can share its memory.
        if (img.isSubmatrix()) {
            throw std::logic_error("Sparse matrix are not supported");
        }
        inferRequest.SetBlob(segmentationInputBlobName, wrapMat2Blob(img));
    }

    /// Inspects the segmentation output blob of a completed request.
    ///
    /// Decodes the blob shape into N/C/H/W (a 3-D blob is treated as single-channel
    /// N,H,W) and validates the rank.
    ///
    /// NOTE(review): the decoded dimensions and the float buffer are currently
    /// unused and the function always returns nullptr — post-processing appears
    /// unfinished; confirm the intended return contract with callers.
    ///
    /// @throws std::logic_error for output blobs that are not 3-D or 4-D.
    /// @return always nullptr (see note above).
    cv::Mat *RoadSegmentation::getResults(InferenceEngine::InferRequest &inferRequest) {
        slog::info << "Processing output blobs" << slog::endl;

        const InferenceEngine::Blob::Ptr output_blob = inferRequest.GetBlob(segmentationOutputBlobName);
        const auto output_data = output_blob->buffer().as<float *>();

        // Query the dims once instead of re-reading the tensor desc per axis.
        const auto &dims = output_blob->getTensorDesc().getDims();
        slog::info << "Output blob has " << dims.size() << " dimensions" << slog::endl;

        size_t N = dims.at(0);
        size_t C, H, W;
        switch (dims.size()) {
            case 3:  // N,H,W — implicit single channel
                C = 1;
                H = dims.at(1);
                W = dims.at(2);
                break;
            case 4:  // N,C,H,W
                C = dims.at(1);
                H = dims.at(2);
                W = dims.at(3);
                break;
            default:
                throw std::logic_error("Unexpected output blob shape. Only 4D and 3D output blobs are supported.");
        }

        // Silence unused warnings until the post-processing above is implemented.
        (void) output_data;
        (void) N; (void) C; (void) H; (void) W;
        return nullptr;
    }
}